/*=========================================================================
 *
 *  Copyright Insight Software Consortium
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkKdTreeGenerator_hxx
#define itkKdTreeGenerator_hxx

#include  "itkKdTreeGenerator.h"

namespace itk
{
namespace Statistics
{
/** Default constructor: no source sample, bucket size 16, a newly created
 *  subsample object and a measurement vector size of 0. */
template< typename TSample >
KdTreeGenerator< TSample >
::KdTreeGenerator()
{
  // The subsample object exists from construction onward; SetSample()
  // later attaches the source sample to it.
  this->m_Subsample = SubsampleType::New();

  this->m_SourceSample = ITK_NULLPTR;
  this->m_MeasurementVectorSize = 0;
  this->m_BucketSize = 16;
}

/** Print the superclass state followed by the source sample (or "not set."
 *  when none has been assigned), the bucket size and the measurement
 *  vector size. */
template< typename TSample >
void
KdTreeGenerator< TSample >
::PrintSelf(std::ostream & os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);

  os << indent << "Source Sample: ";
  if ( m_SourceSample == ITK_NULLPTR )
    {
    os << "not set." << std::endl;
    }
  else
    {
    os << m_SourceSample << std::endl;
    }

  os << indent << "Bucket Size: " << m_BucketSize << std::endl;
  os << indent << "MeasurementVectorSize: " << m_MeasurementVectorSize << std::endl;
}

/** Set the source sample from which the k-d tree is generated.
 *
 *  The sample is stored, attached to the internal subsample (which is then
 *  initialized with all instances), and its measurement vector size is
 *  cached. The temporary bound/mean vectors used while building the tree
 *  are resized to that length.
 *
 *  \param sample  the source sample; must not be null.
 *
 *  An exception is raised through itkExceptionMacro when \c sample is null.
 *  In that case no member of this object is modified. */
template< typename TSample >
void
KdTreeGenerator< TSample >
::SetSample(TSample *sample)
{
  // The pointer is dereferenced below, so reject a null sample up front
  // instead of running into undefined behaviour.
  if ( sample == ITK_NULLPTR )
    {
    itkExceptionMacro(<< "The source sample is a null pointer");
    }

  m_SourceSample = sample;
  m_Subsample->SetSample(sample);
  m_Subsample->InitializeWithAllInstances();
  m_MeasurementVectorSize = sample->GetMeasurementVectorSize();

  // Size the scratch vectors used by GenerateNonterminalNode().
  NumericTraits<MeasurementVectorType>::SetLength(m_TempLowerBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(m_TempUpperBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(m_TempMean, m_MeasurementVectorSize);
}

/** Store the bucket size. GenerateTreeLoop() turns any index range holding
 *  at most this many instances into a terminal node.
 *
 *  \param size  the new bucket size. */
template< typename TSample >
void
KdTreeGenerator< TSample >
::SetBucketSize(unsigned int size)
{
  this->m_BucketSize = size;
}

/** Build the k-d tree for the current source sample.
 *
 *  Does nothing when no source sample has been set. The tree object is
 *  created (and given the source sample and bucket size) only if it does
 *  not exist yet. An exception is raised through itkExceptionMacro when the
 *  subsample's measurement vector size differs from the one reported by
 *  this generator. */
template< typename TSample >
void
KdTreeGenerator< TSample >
::GenerateData()
{
  // Nothing to build without a source sample.
  if ( m_SourceSample == ITK_NULLPTR )
    {
    return;
    }

  // Lazily create the output tree on first use.
  if ( m_Tree.IsNull() )
    {
    m_Tree = KdTreeType::New();
    m_Tree->SetSample(m_SourceSample);
    m_Tree->SetBucketSize(m_BucketSize);
    }

  // Sanity check: the subsample and this generator must agree on the
  // measurement vector length.
  SubsamplePointer currentSubsample = this->GetSubsample();
  if ( this->GetMeasurementVectorSize() != currentSubsample->GetMeasurementVectorSize() )
    {
    itkExceptionMacro(<< "Measurement Vector Length mismatch");
    }

  // Initial bounds handed to the recursion: every component of the lower
  // bound is NonpositiveMin() and every component of the upper bound is
  // max() of the measurement type.
  MeasurementVectorType treeLowerBound;
  MeasurementVectorType treeUpperBound;
  NumericTraits<MeasurementVectorType>::SetLength(treeLowerBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(treeUpperBound, m_MeasurementVectorSize);

  const MeasurementType lowestValue = NumericTraits< MeasurementType >::NonpositiveMin();
  const MeasurementType highestValue = NumericTraits< MeasurementType >::max();
  for ( unsigned int dim = 0; dim < m_MeasurementVectorSize; ++dim )
    {
    treeLowerBound[dim] = lowestValue;
    treeUpperBound[dim] = highestValue;
    }

  // Recurse over the full index range [0, Size()) of the subsample,
  // starting at level 0, and install the result as the tree root.
  KdTreeNodeType *rootNode =
    this->GenerateTreeLoop(0, m_Subsample->Size(), treeLowerBound, treeUpperBound, 0);
  m_Tree->SetRoot(rootNode);
}

/** Create a nonterminal node for the subsample index range
 *  [beginIndex, endIndex).
 *
 *  The partition dimension is the one with the largest spread (upper bound
 *  minus lower bound) over the range; the partition value is obtained with
 *  Algorithm::NthElement at the middle offset of the range. The two
 *  sub-ranges on either side of the median index are built recursively
 *  through GenerateTreeLoop(), and the instance at the median index is
 *  stored in the new node itself.
 *
 *  \param beginIndex  first subsample index of the range.
 *  \param endIndex    one past the last subsample index of the range.
 *  \param lowerBound  current lower bounds; modified during the recursion
 *                     and restored before returning.
 *  \param upperBound  current upper bounds; modified during the recursion
 *                     and restored before returning.
 *  \param level       level value forwarded (plus one) to the recursive calls.
 *  \return the newly allocated nonterminal node. */
template< typename TSample >
inline typename KdTreeGenerator< TSample >::KdTreeNodeType *
KdTreeGenerator< TSample >
::GenerateNonterminalNode(unsigned int beginIndex,
                          unsigned int endIndex,
                          MeasurementVectorType & lowerBound,
                          MeasurementVectorType & upperBound,
                          unsigned int level)
{
  typedef typename KdTreeType::KdTreeNodeType NodeType;
  typedef KdTreeNonterminalNode< TSample >    NonterminalNodeType;

  SubsamplePointer subsample = this->GetSubsample();

  // Compute per-dimension bounds (and mean) of the range into the scratch
  // vectors so the most widely spread dimension can be found.
  Algorithm::FindSampleBoundAndMean< SubsampleType >(subsample,
                                                     beginIndex, endIndex,
                                                     m_TempLowerBound, m_TempUpperBound,
                                                     m_TempMean);

  // Because of the ">=" comparison, the last dimension wins among
  // dimensions with equal spread.
  unsigned int    cutDimension = 0;
  MeasurementType widestSpread = NumericTraits< MeasurementType >::NonpositiveMin();
  for ( unsigned int dim = 0; dim < m_MeasurementVectorSize; dim++ )
    {
    MeasurementType dimSpread = m_TempUpperBound[dim] - m_TempLowerBound[dim];
    if ( dimSpread >= widestSpread )
      {
      widestSpread = dimSpread;
      cutDimension = dim;
      }
    }

  //
  // Find the medial element by using the NthElement function
  // based on the STL implementation of the QuickSelect algorithm.
  // The offset handed to NthElement is measured from beginIndex.
  //
  unsigned int    medianOffset = ( endIndex - beginIndex ) / 2;
  MeasurementType cutValue =
    Algorithm::NthElement< SubsampleType >(m_Subsample,
                                           cutDimension,
                                           beginIndex, endIndex,
                                           medianOffset);

  // Absolute subsample index of the median element.
  const unsigned int medianIndex = medianOffset + beginIndex;

  // Remember the incoming bounds of the cutting dimension so they can be
  // put back after each recursive call.
  const MeasurementType savedLowerBound = lowerBound[cutDimension];
  const MeasurementType savedUpperBound = upperBound[cutDimension];

  // Left subtree: range [beginIndex, medianIndex), upper bound clipped to
  // the partition value.
  upperBound[cutDimension] = cutValue;
  NodeType *leftChild =
    this->GenerateTreeLoop(beginIndex, medianIndex, lowerBound, upperBound, level + 1);
  upperBound[cutDimension] = savedUpperBound;

  // Right subtree: range [medianIndex + 1, endIndex), lower bound raised
  // to the partition value.
  lowerBound[cutDimension] = cutValue;
  NodeType *rightChild =
    this->GenerateTreeLoop(medianIndex + 1, endIndex, lowerBound, upperBound, level + 1);
  lowerBound[cutDimension] = savedLowerBound;

  NonterminalNodeType *node =
    new NonterminalNodeType(cutDimension, cutValue, leftChild, rightChild);

  // The median instance belongs to neither child range; it is kept in the
  // nonterminal node.
  node->AddInstanceIdentifier( subsample->GetInstanceIdentifier(medianIndex) );

  return node;
}

/** Recursive tree-building step for the subsample index range
 *  [beginIndex, endIndex).
 *
 *  - More instances than the bucket size: delegate to
 *    GenerateNonterminalNode().
 *  - Empty range: return the tree's shared empty terminal node.
 *  - Otherwise: allocate a terminal node holding the instance identifiers
 *    of the range.
 *
 *  \param beginIndex  first subsample index of the range.
 *  \param endIndex    one past the last subsample index of the range.
 *  \param lowerBound  current lower bounds, forwarded to the nonterminal case.
 *  \param upperBound  current upper bounds, forwarded to the nonterminal case.
 *  \param level       level value forwarded (plus one) to the nonterminal case.
 *  \return pointer to the node that represents the range. */
template< typename TSample >
inline typename KdTreeGenerator< TSample >::KdTreeNodeType *
KdTreeGenerator< TSample >
::GenerateTreeLoop(unsigned int beginIndex,
                   unsigned int endIndex,
                   MeasurementVectorType & lowerBound,
                   MeasurementVectorType & upperBound,
                   unsigned int level)
{
  // Too many instances for a single bucket: split further.
  // NOTE(review): level is incremented here and again in the recursive
  // calls made by GenerateNonterminalNode(), so it grows by two per tree
  // depth -- confirm this is intended.
  if ( ( endIndex - beginIndex ) > m_BucketSize )
    {
    return this->GenerateNonterminalNode(beginIndex, endIndex,
                                         lowerBound, upperBound, level + 1);
    }

  // Empty range: hand back the tree's empty terminal node.
  if ( endIndex == beginIndex )
    {
    return m_Tree->GetEmptyTerminalNode();
    }

  // Small, non-empty range: collect its instance identifiers in a new
  // terminal node.
  typedef KdTreeTerminalNode< TSample > TerminalNodeType;
  TerminalNodeType *leaf = new TerminalNodeType();

  SubsamplePointer subsample = this->GetSubsample();
  for ( unsigned int index = beginIndex; index < endIndex; ++index )
    {
    leaf->AddInstanceIdentifier( subsample->GetInstanceIdentifier(index) );
    }

  return leaf;
}
} // end of namespace Statistics
} // end of namespace itk

#endif
